
Handle sequence_lens for GRU on CPU #2479

Merged — 18 commits, Sep 8, 2023
Conversation

@chentong319 (Collaborator) commented Sep 5, 2023

This PR is a quick fix for sequence_lens. According to PyTorch's definition, the padding value is used for every step after a sequence reaches its sequence length. This PR does not try to save computation; I will open another PR that uses scf.if so that all the RNN ops can be handled and computation is saved.
The output of my GRU test case appears to match the PyTorch example.

module{
func.func @main_graph(%arg0: tensor<2x2x1xf32>, %arg1: tensor<1x3x1xf32>, %arg2 : tensor<1x3x1xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
  %lens = onnx.Constant dense<[2, 1]> : tensor<2xi32>
  %cst = "onnx.NoValue"() {value} : () -> none
  %Y, %Y_h = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %lens, %cst) : ( tensor<2x2x1xf32>, tensor<1x3x1xf32>, tensor<1x3x1xf32>, none, tensor<2xi32>, none) -> (tensor<*xf32>, tensor<*xf32>)
 onnx.Return %Y, %Y_h : tensor<*xf32>, tensor<*xf32>
}
"onnx.EntryPoint"() {func = @main_graph} : () -> ()
}

The output is

The 1st output output_1:[2x1x2x1xfloat32] is: 
 [[[[ 0.0011079 ]
   [-0.00399583]]]


 [[[-0.001489  ]
   [ 0.        ]]]] 

The 2nd output output_1:[1x2x1xfloat32] is: 
 [[[-0.001489]
  [ 0.      ]]]
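To make the padding behavior concrete, here is a minimal pure-Python sketch (not the actual MLIR lowering; `step` is a hypothetical stand-in for the real GRU cell update): once t reaches sequence_lens[b], batch b's hidden state is replaced by the padding value, i.e. the initial hidden state, or zero when none is given.

```python
def run_with_seq_lens(step, xs, seq_lens, h0):
    """Apply `step` over time, padding finished sequences.

    xs: list over time steps; xs[t][b] is the input for batch b at step t.
    h0: initial hidden state per batch (also the padding value in this PR).
    Returns (Y, Y_h): per-step hidden states and the final hidden state.
    """
    h = list(h0)
    ys = []
    for t, xt in enumerate(xs):
        for b in range(len(h)):
            h_new = step(xt[b], h[b])          # normal GRU-cell update (stand-in)
            # Once the sequence is exhausted, emit the padding value instead.
            h[b] = h0[b] if t >= seq_lens[b] else h_new
        ys.append(list(h))
    return ys, h
```

With `step = lambda x, h: h + x`, `seq_lens = [2, 1]`, and a zero initial state, batch 1 is padded with 0 from t = 1 on, matching the zeros in the test output above.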

Limitations: this PR does not use the sequence_lens info to save computation. To do that, I could add an scf.if inside the loops over sequence and batch. However, the existing implementation defines the loop nest over batch and hidden state together, so some effort is needed to break up the loop nest. It is doable, but what is the priority?
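The computation-saving variant would in effect guard the cell update with a per-batch condition. A hypothetical sketch of the control flow an scf.if would express (pure Python for illustration; the real change would restructure the MLIR loop nest, not Python):

```python
def run_skipping(step, xs, seq_lens, h0):
    """Same result as the current lowering, but the cell update is
    skipped entirely for batches whose sequence is already exhausted."""
    h = list(h0)
    ys = []
    for t, xt in enumerate(xs):
        for b in range(len(h)):
            if t < seq_lens[b]:              # corresponds to the scf.if guard
                h[b] = step(xt[b], h[b])     # compute only for live sequences
            else:
                h[b] = h0[b]                 # just write the padding value
        ys.append(list(h))
    return ys, h
```

The outputs are identical to the unguarded version; the guard only avoids dead cell updates.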

Question: should the final result be modified according to sequence_lens? For example, should the 2nd output be [[[-0.001489] [-0.00399583]]]? I did not find any specification for that.
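For comparison, the alternative semantics raised in this question, where Y stays padded but Y_h keeps each batch's last valid hidden state (the way PyTorch's packed-sequence final hidden state behaves), differs by only one line. Again a hypothetical pure-Python sketch, not a statement of what the ONNX spec requires:

```python
def run_keep_last(step, xs, seq_lens, h0):
    """Pad Y after sequence_lens, but let Y_h keep the last valid state."""
    h = list(h0)                             # carried hidden state
    ys = []
    for t, xt in enumerate(xs):
        y_t = []
        for b in range(len(h)):
            if t < seq_lens[b]:
                h[b] = step(xt[b], h[b])     # update only while the sequence is live
                y_t.append(h[b])
            else:
                y_t.append(h0[b])            # Y is still padded...
                # ...but h[b] is left untouched, keeping the last valid state
        ys.append(y_t)
    return ys, h
```

For seq_lens = [2, 1], Y_h would then contain batch 1's hidden state from t = 0 rather than the padding value.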

@chentong319 chentong319 marked this pull request as draft September 5, 2023 18:27
@chentong319 chentong319 changed the title Handel sequence_lens for GRU on CPU Handle sequence_lens for GRU on CPU Sep 5, 2023
@chentong319 chentong319 marked this pull request as ready for review September 6, 2023 00:10
@chentong319 (Collaborator, Author) commented:
Another test case, this time exercising the initial_h input.

module{
func.func @main_graph(%arg0: tensor<4x3x1xf32>, %arg1: tensor<1x6x1xf32>, %arg2 : tensor<1x6x2xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
  %lens = onnx.Constant dense<[2,3,1]> : tensor<3xi32>
  %initial = onnx.Constant dense<[[[0., 1.],[2.0, 3.0],[4.0, 5.0]]]> : tensor<1x3x2xf32>
  %cst = "onnx.NoValue"() {value} : () -> none
  %Y, %Y_h = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %lens, %initial) : 
    ( tensor<4x3x1xf32>, tensor<1x6x1xf32>, tensor<1x6x2xf32>, none, tensor<3xi32>, tensor<1x3x2xf32>) 
    -> (tensor<*xf32>, tensor<*xf32>)
 onnx.Return %Y, %Y_h : tensor<*xf32>, tensor<*xf32>
}
"onnx.EntryPoint"() {func = @main_graph} : () -> ()
}

The result:

The 1st output output_1:[4x1x3x2xfloat32] is: 
 [[[[1.2503355e-03 4.5813003e-01]
   [8.9796937e-01 1.2981267e+00]
   [1.6873405e+00 2.0353923e+00]]]


 [[[5.4413028e-04 2.1451449e-01]
   [4.2194879e-01 5.9896427e-01]
   [4.0000000e+00 5.0000000e+00]]]


 [[[0.0000000e+00 1.0000000e+00]
   [2.0093442e-01 2.8191486e-01]
   [4.0000000e+00 5.0000000e+00]]]


 [[[0.0000000e+00 1.0000000e+00]
   [2.0000000e+00 3.0000000e+00]
   [4.0000000e+00 5.0000000e+00]]]] 

The 2nd output output_1:[1x3x2xfloat32] is: 
 [[[0. 1.]
  [2. 3.]
  [4. 5.]]] 

Value cond = createMath.sge(
    createMath.cast(sequenceUB.getType(), sequenceIV), sequenceUB);
nextHt = createMath.select(cond, /*padding*/ initial, nextHt);
}
Collaborator commented:
Could we create a common function for this to avoid boilerplate? Then we could call it in other ops such as LSTM and RNN.

Collaborator (Author) replied:
Changed.

@chentong319 (Collaborator, Author) commented:
Now both the first and second outputs of GRU match the torch GRU example.

@tungld (Collaborator) left a comment:
LGTM!

@chentong319 chentong319 merged commit e3a8a67 into onnx:main Sep 8, 2023
4 checks passed
@chentong319 chentong319 deleted the gru-seq-cpu branch September 8, 2023 00:03
@jenkins-droid: Jenkins Linux s390x Build #12569 [push] Handle sequence_lens for... started at 20:04
@jenkins-droid: Jenkins Linux ppc64le Build #11562 [push] Handle sequence_lens for... started at 20:13
@jenkins-droid: Jenkins Linux amd64 Build #12557 [push] Handle sequence_lens for... started at 19:04
@jenkins-droid: Jenkins Linux amd64 Build #12557 [push] Handle sequence_lens for... passed after 1 hr 5 min
@jenkins-droid: Jenkins Linux s390x Build #12569 [push] Handle sequence_lens for... passed after 1 hr 24 min
@jenkins-droid: Jenkins Linux ppc64le Build #11562 [push] Handle sequence_lens for... passed after 1 hr 44 min
